{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U pip\n",
    "\n",
    "# If you don't have ClearML installed then uncomment this line\n",
    "# !pip install -U clearml>=0.16.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "e-YsQrBjzNdX"
   },
   "outputs": [],
   "source": [
    "!pip install -U torch==1.5.1\n",
    "!pip install -U torchaudio==0.5.1\n",
    "!pip install -U torchvision==0.6.1\n",
    "!pip install -U matplotlib==3.2.1\n",
    "!pip install -U pandas==1.0.4\n",
    "!pip install -U numpy==1.18.4\n",
    "!pip install -U tensorboard==2.2.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "T7T0Rf26zNdm"
   },
   "outputs": [],
   "source": [
    "import PIL\n",
    "import io\n",
    "\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from pathlib2 import Path\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import Dataset\n",
    "from torch.utils.tensorboard import SummaryWriter\n",
    "\n",
    "import torchaudio\n",
    "from torchvision.transforms import ToTensor\n",
    "from torchvision import models\n",
    "\n",
    "from clearml import Task\n",
    "from clearml.storage import StorageManager\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "task = Task.init(project_name='Audio Example', task_name='audio classification UrbanSound8K')\n",
    "configuration_dict = {'number_of_epochs': 3, 'batch_size': 8, 'dropout': 0.3, 'base_lr': 0.005, \n",
    "                      'number_of_mel_filters': 64, 'resample_freq': 22050}\n",
    "configuration_dict = task.connect(configuration_dict)  # enabling configuration override by clearml\n",
    "print(configuration_dict)  # printing actual configuration (after override in remote mode)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "msiz7QdvzNeA",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Download UrbanSound8K dataset (https://urbansounddataset.weebly.com/urbansound8k.html)\n",
    "# For simplicity we will use here a subset of that dataset using clearml StorageManager\n",
    "path_to_UrbanSound8K = StorageManager.get_local_copy(\"https://allegro-datasets.s3.amazonaws.com/clearml/UrbanSound8K.zip\", \n",
    "                                                     extract_archive=True)\n",
    "path_to_UrbanSound8K_csv = Path(path_to_UrbanSound8K) / 'UrbanSound8K' / 'metadata' / 'UrbanSound8K.csv'\n",
    "path_to_UrbanSound8K_audio = Path(path_to_UrbanSound8K) / 'UrbanSound8K' / 'audio'"
   ]
  },
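  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (not required for training): peek at the metadata CSV.\n",
    "# Each row describes one clip; the fold and class columns drive the train/test split below.\n",
    "pd.read_csv(path_to_UrbanSound8K_csv).head()"
   ]
  },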
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "wXtmZe7yzNeS"
   },
   "outputs": [],
   "source": [
    "class UrbanSoundDataset(Dataset):\n",
    "    def __init__(self, csv_path, file_path, folderList, resample_freq=0, return_audio=False):\n",
    "        self.file_path = file_path\n",
    "        self.file_names = []\n",
    "        self.labels = []\n",
    "        self.folders = []\n",
    "        self.n_mels = configuration_dict.get('number_of_mel_filters', 64)\n",
    "        self.return_audio = return_audio\n",
    "        self.resample = resample_freq\n",
    "        \n",
    "        #loop through the csv entries and only add entries from folders in the folder list\n",
    "        csvData = pd.read_csv(csv_path)\n",
    "        for i in range(0,len(csvData)):\n",
    "            if csvData.iloc[i, 5] in folderList:\n",
    "                self.file_names.append(csvData.iloc[i, 0])\n",
    "                self.labels.append(csvData.iloc[i, 6])\n",
    "                self.folders.append(csvData.iloc[i, 5])\n",
    "        \n",
    "    def __getitem__(self, index):\n",
    "        #format the file path and load the file\n",
    "        path = self.file_path / (\"fold\" + str(self.folders[index])) / self.file_names[index]\n",
    "        soundData, sample_rate = torchaudio.load(path, out = None, normalization = True)\n",
    "\n",
    "        if self.resample > 0:\n",
    "            resample_transform = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.resample)\n",
    "            soundData = resample_transform(soundData)\n",
    "        \n",
    "        # This will convert audio files with two channels into one\n",
    "        soundData = torch.mean(soundData, dim=0, keepdim=True)\n",
    "               \n",
    "        # Convert audio to log-scale Mel spectrogram\n",
    "        melspectrogram_transform = torchaudio.transforms.MelSpectrogram(sample_rate=self.resample, n_mels=self.n_mels)\n",
    "        melspectrogram = melspectrogram_transform(soundData)\n",
    "        melspectogram_db = torchaudio.transforms.AmplitudeToDB()(melspectrogram)\n",
    "        \n",
    "        #Make sure all spectrograms are the same size\n",
    "        fixed_length = 3 * (self.resample//200)\n",
    "        if melspectogram_db.shape[2] < fixed_length:\n",
    "            melspectogram_db = torch.nn.functional.pad(melspectogram_db, (0, fixed_length - melspectogram_db.shape[2]))\n",
    "        else:\n",
    "            melspectogram_db = melspectogram_db[:, :, :fixed_length]\n",
    "        \n",
    "        if self.return_audio:\n",
    "            fixed_length = 3 * self.resample\n",
    "            if soundData.numel() < fixed_length:\n",
    "                soundData = torch.nn.functional.pad(soundData, (0, fixed_length - soundData.numel())).numpy()\n",
    "            else:\n",
    "                soundData = soundData[0,:fixed_length].reshape(1,fixed_length).numpy()\n",
    "        else:\n",
    "            soundData = np.array([])\n",
    "\n",
    "        return soundData, self.resample, melspectogram_db, self.labels[index]\n",
    "    \n",
    "    def __len__(self):\n",
    "        return len(self.file_names)\n",
    "\n",
    "train_set = UrbanSoundDataset(path_to_UrbanSound8K_csv, path_to_UrbanSound8K_audio, range(1,10), \n",
    "                              resample_freq=configuration_dict.get('resample_freq', 0), return_audio=False)\n",
    "test_set = UrbanSoundDataset(path_to_UrbanSound8K_csv, path_to_UrbanSound8K_audio, [10], \n",
    "                             resample_freq=configuration_dict.get('resample_freq', 0), return_audio=True)\n",
    "print(\"Train set size: \" + str(len(train_set)))\n",
    "print(\"Test set size: \" + str(len(test_set)))\n",
    "\n",
    "train_loader = torch.utils.data.DataLoader(train_set, batch_size = configuration_dict.get('batch_size', 4), \n",
    "                                           shuffle = True, pin_memory=True, num_workers=1)\n",
    "test_loader = torch.utils.data.DataLoader(test_set, batch_size = configuration_dict.get('batch_size', 4), \n",
    "                                          shuffle = False, pin_memory=False, num_workers=1)\n",
    "\n",
    "classes = ('air_conditioner', 'car_horn', 'children_playing', 'dog_bark', 'drilling', 'engine_idling', \n",
    "           'gun_shot', 'jackhammer', 'siren', 'street_music')"
   ]
  },
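  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: fetch one item and confirm the spectrogram has the expected\n",
    "# (channels, n_mels, fixed_length) shape computed in __getitem__ above\n",
    "_, _, sample_spectrogram, sample_label = train_set[0]\n",
    "print('Spectrogram shape: {}, label: {}'.format(sample_spectrogram.shape, classes[sample_label]))"
   ]
  },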
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = models.resnet18(pretrained=True)\n",
    "model.conv1=nn.Conv2d(1, model.conv1.out_channels, kernel_size=model.conv1.kernel_size[0], \n",
    "                      stride=model.conv1.stride[0], padding=model.conv1.padding[0])\n",
    "num_ftrs = model.fc.in_features\n",
    "model.fc = nn.Sequential(*[nn.Dropout(p=configuration_dict.get('dropout', 0.25)), nn.Linear(num_ftrs, len(classes))])"
   ]
  },
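  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional shape check: run a dummy single-channel spectrogram through the adapted network\n",
    "# and confirm we get one logit per class. The dummy dimensions mirror the dataset's fixed_length.\n",
    "with torch.no_grad():\n",
    "    dummy_input = torch.randn(1, 1, configuration_dict.get('number_of_mel_filters', 64),\n",
    "                              3 * (configuration_dict.get('resample_freq', 22050) // 200))\n",
    "    print('Model output shape: {}'.format(model(dummy_input).shape))  # expected: [1, len(classes)]"
   ]
  },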
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3yKYru14zNef"
   },
   "outputs": [],
   "source": [
    "optimizer = optim.SGD(model.parameters(), lr = configuration_dict.get('base_lr', 0.001), momentum = 0.9)\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size = configuration_dict.get('number_of_epochs')//3, gamma = 0.1)\n",
    "criterion = nn.CrossEntropyLoss()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device('cpu')\n",
    "print('Device to use: {}'.format(device))\n",
    "model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tensorboard_writer = SummaryWriter('./tensorboard_logs')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_signal(signal, title, cmap=None):\n",
    "    fig = plt.figure()\n",
    "    if signal.ndim == 1:\n",
    "        plt.plot(signal)\n",
    "    else:\n",
    "        plt.imshow(signal, cmap=cmap)    \n",
    "    plt.title(title)\n",
    "    \n",
    "    plot_buf = io.BytesIO()\n",
    "    plt.savefig(plot_buf, format='jpeg')\n",
    "    plot_buf.seek(0)\n",
    "    plt.close(fig)\n",
    "    return ToTensor()(PIL.Image.open(plot_buf))"
   ]
  },
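  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: render one training sample with the helper above and display it inline.\n",
    "# This is the same kind of image the train/test loops below report to TensorBoard.\n",
    "_, _, spec, label = train_set[0]\n",
    "img = plot_signal(spec.squeeze().numpy(), classes[label], 'hot')\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.imshow(img.permute(1, 2, 0))  # ToTensor returns (C, H, W); matplotlib expects (H, W, C)\n",
    "plt.axis('off')"
   ]
  },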
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "Vdthqz3JzNem"
   },
   "outputs": [],
   "source": [
    "def train(model, epoch):\n",
    "    model.train()\n",
    "    for batch_idx, (sounds, sample_rate, inputs, labels) in enumerate(train_loader):\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = model(inputs)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        \n",
    "        iteration = epoch * len(train_loader) + batch_idx\n",
    "        if batch_idx % log_interval == 0: #print training stats\n",
    "            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'\n",
    "                  .format(epoch, batch_idx * len(inputs), len(train_loader.dataset), \n",
    "                          100. * batch_idx / len(train_loader), loss))\n",
    "            tensorboard_writer.add_scalar('training loss/loss', loss, iteration)\n",
    "            tensorboard_writer.add_scalar('learning rate/lr', optimizer.param_groups[0]['lr'], iteration)\n",
    "                \n",
    "        \n",
    "        if batch_idx % debug_interval == 0:    # report debug image every \"debug_interval\" mini-batches\n",
    "            for n, (inp, pred, label) in enumerate(zip(inputs, predicted, labels)):\n",
    "                series = 'label_{}_pred_{}'.format(classes[label.cpu()], classes[pred.cpu()])\n",
    "                tensorboard_writer.add_image('Train MelSpectrogram samples/{}_{}_{}'.format(batch_idx, n, series), \n",
    "                                             plot_signal(inp.cpu().numpy().squeeze(), series, 'hot'), iteration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "LBWoj7u5zNes"
   },
   "outputs": [],
   "source": [
    "def test(model, epoch):\n",
    "    model.eval()\n",
    "    class_correct = list(0. for i in range(len(classes)))\n",
    "    class_total = list(0. for i in range(len(classes)))\n",
    "    with torch.no_grad():\n",
    "        for idx, (sounds, sample_rate, inputs, labels) in enumerate(test_loader):\n",
    "            inputs = inputs.to(device)\n",
    "            labels = labels.to(device)\n",
    "\n",
    "            outputs = model(inputs)\n",
    "\n",
    "            _, predicted = torch.max(outputs, 1)\n",
    "            c = (predicted == labels)\n",
    "            for i in range(len(inputs)):\n",
    "                label = labels[i].item()\n",
    "                class_correct[label] += c[i].item()\n",
    "                class_total[label] += 1\n",
    "        \n",
    "            iteration = (epoch + 1) * len(train_loader)\n",
    "            if idx % debug_interval == 0:    # report debug image every \"debug_interval\" mini-batches\n",
    "                for n, (sound, inp, pred, label) in enumerate(zip(sounds, inputs, predicted, labels)):\n",
    "                    series = 'label_{}_pred_{}'.format(classes[label.cpu()], classes[pred.cpu()])\n",
    "                    tensorboard_writer.add_audio('Test audio samples/{}_{}_{}'.format(idx, n, series), \n",
    "                                                 sound, iteration, int(sample_rate[n]))\n",
    "                    tensorboard_writer.add_image('Test MelSpectrogram samples/{}_{}_{}'.format(idx, n, series), \n",
    "                                                 plot_signal(inp.cpu().numpy().squeeze(), series, 'hot'), iteration)\n",
    "\n",
    "    total_accuracy = 100 * sum(class_correct)/sum(class_total)\n",
    "    print('[Iteration {}] Accuracy on the {} test images: {}%\\n'.format(epoch, sum(class_total), total_accuracy))\n",
    "    tensorboard_writer.add_scalar('accuracy/total', total_accuracy, iteration)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "X5lx3g_5zNey"
   },
   "outputs": [],
   "source": [
    "log_interval = 10\n",
    "debug_interval = 25\n",
    "for epoch in range(configuration_dict.get('number_of_epochs', 10)):\n",
    "    train(model, epoch)\n",
    "    test(model, epoch)\n",
    "    scheduler.step()"
   ]
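  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Close the TensorBoard writer to flush any buffered events to disk\n",
    "tensorboard_writer.close()"
   ]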
  }
 ],
 "metadata": {
  "colab": {
   "name": "audio_classifier_tutorial.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
